library(tidyverse)
# remotes::install_github("nickbrazeau/discent", ref = "develop")
library(discent)
source("R/polysim_wrappers.R")
source("R/discent_wrappers.R")
source("R/utils.R")
Model conceptualization: \[ r_{ij} = \left(\frac{f_i + f_j}{2}\right) \exp\left\{ \frac{-d_{ij}}{M} \right\} \]
Coding Decisions: 1. M and F parameters have different learning rates (uncoupled) 2. M must be positive 3. ADAM gradient descent optimizer
Using polySimIBD, we have five simulation frameworks,
for each of which we ran 100 independent realizations (i.e. same migration,
same Ne, different seed). For each framework, there are 25 demes laid
out on a 5x5 square plane. Frameworks are as follows:
The goal is to find the best start parameters for each framework. We selected
the first realization and assumed that start parameters would be
consistent across realizations within each framework. However, this also
provides the opportunity to determine how much start parameters affect
final estimates of DISCs (called final_fis in code) and
final global migration rates (called final_ms). Our search
grid looks over the following 11220 combinations:
# look-up tables: candidate start values and learning rates for the search
fstartsvec <- seq(0.1, 0.9, by = 0.05)
mstartsvec <- 10^(-(1:6))
flearnsvec <- 10^(-(1:10))
mlearnsvec <- 10^(-(5:15))
# one row per combination of the four vectors above, with readable names
search_grid <- stats::setNames(
  expand.grid(fstartsvec, mstartsvec, flearnsvec, mlearnsvec),
  c("fstart", "mstart", "f_learn", "m_learn")
)
#............................................................
# Part 1: Model Stability
#...........................................................
# model framework inputs
migmatdf <- readRDS("simdata/simulation_setup/inputs/migmat_framework.RDS")
coordgrid <- readRDS("simdata/simulation_setup/inputs/squarecoords.rds")
# stored search grid results
search_grid_fullraw <- readRDS("disc_results/search_grid_full_for_discdat.RDS")
# derive per-row summaries from the `discret` column: a numeric final cost,
# a list-column of final f estimates and a numeric final m estimate
# (extract_final_cost / get_fs / get_ms are not defined in this file;
# presumably they come from the sourced R/ scripts -- confirm)
search_grid_full <- dplyr::mutate(
  search_grid_fullraw,
  finalcost = purrr::map_dbl(discret, extract_final_cost),
  final_fs = purrr::map(discret, get_fs),
  final_ms = purrr::map_dbl(discret, get_ms)
)
# distribution of final costs for each model framework
ggplot(search_grid_full, aes(x = modname, y = finalcost)) +
  geom_boxplot() +
  labs(title = "All Costs over Model Frameworks") +
  theme_linedraw()
We will restrict the search grid results to runs with a reasonable final cost (below 1e5).
# keep only the runs whose final cost is below 1e5
search_grid_full <- dplyr::filter(search_grid_full, finalcost < 1e5)
# final cost against every final f estimate, one panel per model framework
# (`rep` is presumably a replicate/realization column of the data -- confirm)
search_grid_full %>%
  tidyr::unnest(cols = "final_fs") %>%
  ggplot(aes(x = finalcost, y = Final_Fis, color = rep)) +
  geom_point() +
  scale_color_viridis_c() +
  facet_wrap(vars(modname)) +
  labs(title = "Costs vs Final Fs") +
  theme_linedraw() +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none"
  )
# final cost against the final m estimate, one panel per model framework
# (`rep` is presumably a replicate/realization column of the data -- confirm)
search_grid_full %>%
  tidyr::unnest(cols = "final_ms") %>%
  ggplot(aes(x = finalcost, y = final_ms, color = rep)) +
  geom_point() +
  scale_color_viridis_c() +
  facet_wrap(vars(modname)) +
  labs(title = "Costs vs Final Ms") +
  theme_linedraw() +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none"
  )
# final f estimates against the final m estimate, coloured by final cost,
# one panel per model framework
search_grid_full %>%
  tidyr::unnest(cols = c("final_fs", "final_ms")) %>%
  ggplot(aes(x = Final_Fis, y = final_ms, color = finalcost)) +
  geom_point() +
  scale_color_viridis_c() +
  facet_wrap(vars(modname)) +
  labs(title = "Final Fs vs Final Ms") +
  theme_linedraw() +
  theme(axis.text.x = element_text(angle = 90))
Note, here we are running the same start parameters identified from search grid above for our 100 realizations for each model framework. Therefore, there are 5x100 final global M estimations and 5x25x100 final Fi estimations.
# drop the search grid objects and trigger garbage collection, purely to
# release memory
rm(search_grid_fullraw, search_grid_full)
gc()
#............................................................
# Part 2: Model Accuracy
#...........................................................
# stored final results
discretsraw <- readRDS("disc_results/final_discresults_for_discdat.RDS")
# derive per-row summaries from the `discret` column: a numeric final cost,
# a list-column of final f estimates and a numeric final m estimate
# (extract_final_cost / get_fs / get_ms are not defined in this file;
# presumably they come from the sourced R/ scripts -- confirm)
discrets <- dplyr::mutate(
  discretsraw,
  finalcost = purrr::map_dbl(discret, extract_final_cost),
  final_fs = purrr::map(discret, get_fs),
  final_ms = purrr::map_dbl(discret, get_ms)
)
# distribution of final costs for each model framework
ggplot(discrets, aes(x = modname, y = finalcost)) +
  geom_boxplot() +
  labs(title = "Frameworks vs Costs") +
  theme_linedraw() +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none"
  )
# distribution of final m estimates for each model framework
ggplot(discrets, aes(x = modname, y = final_ms)) +
  geom_boxplot() +
  labs(title = "Frameworks vs Final Ms") +
  theme_linedraw() +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none"
  )
# final cost against the final m estimate, one facet row per model framework
# with free axis scales. No colour aesthetic is mapped in this plot, so no
# colour scale is added.
discrets %>%
  ggplot(aes(x = finalcost, y = final_ms)) +
  geom_point(alpha = 0.5) +
  facet_grid(rows = vars(modname), scales = "free") +
  ggtitle("Model Frameworks vs Final Ms") +
  theme_linedraw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
# final f estimates per deme, coloured by final cost, one facet row per
# model framework
discrets %>%
  tidyr::unnest(cols = "final_fs") %>%
  ggplot(aes(x = Deme, y = Final_Fis, color = finalcost)) +
  ggbeeswarm::geom_beeswarm(alpha = 0.5) +
  scale_color_viridis_c() +
  facet_grid(rows = vars(modname)) +
  labs(title = "Demes vs Final Fs") +
  theme_linedraw() +
  theme(axis.text.x = element_text(angle = 90))
# the join below matches on `deme`, so store the coordinate key as character
# (presumably `Deme` inside final_fs is character as well -- confirm)
coordgrid$deme <- as.character(coordgrid$deme)
# mean / variance / min / max of the final f estimates for every
# framework-deme pair, with the deme coordinates attached
discretsum <- discrets %>%
  tidyr::unnest(cols = "final_fs") %>%
  dplyr::group_by(modname, Deme) %>%
  dplyr::summarise(
    meanFi = mean(Final_Fis),
    varFi = var(Final_Fis),
    minFi = min(Final_Fis),
    maxFi = max(Final_Fis)
  ) %>%
  # summarise() only peels off the last grouping level; drop the leftover
  # `modname` grouping so later verbs operate on a plain, ungrouped table
  dplyr::ungroup() %>%
  dplyr::rename(deme = Deme) %>%
  dplyr::left_join(coordgrid, by = "deme")
# Draw one per-deme summary statistic on the deme grid, with one facet row
# per model framework.
#   dat          data frame with columns longnum, latnum, deme, modname and
#                the column named by `stat_col`
#   stat_col     name (string) of the column mapped to point colour
#   legend_title title of the colour legend
#   plot_title   plot title
# Returns the ggplot object (auto-printed when called at top level).
plot_deme_stat_map <- function(dat, stat_col, legend_title, plot_title) {
  ggplot(dat, aes(x = longnum, y = latnum)) +
    # `.data[[...]]` looks the colour column up by its string name
    geom_point(aes(color = .data[[stat_col]]), size = 5) +
    geom_text(aes(label = deme), color = "#ffffff", size = 3) +
    scale_color_viridis_c(legend_title) +
    facet_grid(rows = vars(modname)) +
    ylim(c(-2, 24)) +
    ggtitle(plot_title) +
    theme_linedraw() +
    theme(axis.text = element_blank())
}

# same four maps as before, in the same order: mean, variance, min, max
plot_deme_stat_map(discretsum, "meanFi", "Mean DISCs", "Mean Fis")
plot_deme_stat_map(discretsum, "varFi", "var DISCs", "Variance Fis")
plot_deme_stat_map(discretsum, "minFi", "Min DISCs", "Min Fis")
plot_deme_stat_map(discretsum, "maxFi", "max DISCs", "max Fis")
Here we will analyze whether misspecified geographic distances result in a worse cost.
# keep only the two isolation-by-distance frameworks
discrets_dist_consider <- dplyr::filter(
  discrets,
  modname %in% c("IsoByDist", "badIsoByDist")
)
# compare their final cost distributions side by side
ggplot(discrets_dist_consider, aes(x = modname, y = finalcost)) +
  geom_boxplot() +
  labs(title = "Costs over IBD with Correct and Misspecified Distance") +
  theme_linedraw()